{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-3.9796828e-04,  1.3983431e+00, -4.0319901e-02, -5.5897278e-01,\n",
       "        4.6789032e-04,  9.1330688e-03,  0.0000000e+00,  0.0000000e+00],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "# Environment definition\n",
    "class MyWrapper(gym.Wrapper):\n",
    "    # Presents LunarLander-v2 through the 4-tuple step API: reset() returns\n",
    "    # only the observation and step() returns (state, reward, done, info).\n",
    "\n",
    "    def __init__(self):\n",
    "        # gym.Wrapper.__init__ stores the wrapped environment as self.env.\n",
    "        super().__init__(gym.make('LunarLander-v2'))\n",
    "\n",
    "    def reset(self, **kwargs):\n",
    "        # Forward optional reset arguments (e.g. seed) and drop the info dict.\n",
    "        state, _ = self.env.reset(**kwargs)\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, terminated, truncated, info = self.env.step(action)\n",
    "        # A time-limit cut-off ends the episode just like a terminal state;\n",
    "        # ignoring it would leave callers stepping a finished episode.\n",
    "        done = terminated or truncated\n",
    "        # Mark pure time-limit endings so callers can tell them apart from\n",
    "        # real terminal states.\n",
    "        info['TimeLimit.truncated'] = truncated and not terminated\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "# Build the environment instance shared by the following cells.\n",
    "env = MyWrapper()\n",
    "\n",
    "# Reset once; the returned observation becomes the cell output.\n",
    "env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "w7vOFlpA_ONz"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env.observation_space= Box([-1.5       -1.5       -5.        -5.        -3.1415927 -5.\n",
      " -0.        -0.       ], [1.5       1.5       5.        5.        3.1415927 5.        1.\n",
      " 1.       ], (8,), float32)\n",
      "env.action_space= Discrete(4)\n",
      "state= [-0.00586414  1.4207945  -0.593993    0.43883783  0.0068019   0.13454841\n",
      "  0.          0.        ]\n",
      "action= 3\n",
      "next_state= [-0.01165171  1.4300904  -0.58351576  0.41312727  0.01151062  0.09418334\n",
      "  0.          0.        ]\n",
      "reward= 0.921825805844976\n",
      "done= False\n"
     ]
    }
   ],
   "source": [
    "# Get to know the game environment\n",
    "def test_env():\n",
    "    # Print the spaces of the global env, then take one randomly sampled\n",
    "    # step from a fresh reset and print the resulting transition.\n",
    "    print('env.observation_space=', env.observation_space)\n",
    "    print('env.action_space=', env.action_space)\n",
    "\n",
    "    first_obs = env.reset()\n",
    "    sampled = env.action_space.sample()\n",
    "    following_obs, step_reward, finished, _ = env.step(sampled)\n",
    "\n",
    "    # Label/value pairs, printed in the order they are listed.\n",
    "    report = (\n",
    "        ('state=', first_obs),\n",
    "        ('action=', sampled),\n",
    "        ('next_state=', following_obs),\n",
    "        ('reward=', step_reward),\n",
    "        ('done=', finished),\n",
    "    )\n",
    "    for label, value in report:\n",
    "        print(label, value)\n",
    "\n",
    "\n",
    "# Run the inspection once.\n",
    "test_env()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "543OHYDfcjK4"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<stable_baselines3.ppo.ppo.PPO at 0x7f8e542914c0>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from stable_baselines3.common.env_util import make_vec_env\n",
    "from stable_baselines3 import PPO\n",
    "\n",
    "# Initialise the model\n",
    "train_env = make_vec_env(MyWrapper, n_envs=4)  # N environments used for training\n",
    "\n",
    "model = PPO(\n",
    "    'MlpPolicy',\n",
    "    train_env,\n",
    "    n_steps=1024,\n",
    "    batch_size=64,\n",
    "    n_epochs=4,\n",
    "    gamma=0.999,\n",
    "    gae_lambda=0.98,\n",
    "    ent_coef=0.01,\n",
    "    verbose=0,\n",
    ")\n",
    "\n",
    "# Echo the instance as the cell output.\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/pt39/lib/python3.9/site-packages/stable_baselines3/common/evaluation.py:67: UserWarning: Evaluation environment is not wrapped with a ``Monitor`` wrapper. This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. Consider wrapping environment first with ``Monitor`` wrapper.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(-236.2219022285426, 122.28696606172053)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from stable_baselines3.common.evaluation import evaluate_policy\n",
    "\n",
    "# Test: score the current model over 10 episodes with sampled actions\n",
    "evaluate_policy(\n",
    "    model,\n",
    "    env,\n",
    "    n_eval_episodes=10,\n",
    "    deterministic=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3dbc4f57035d451591303f3fee83271c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Train for 200,000 timesteps, then save the model to disk\n",
    "model.learn(\n",
    "    total_timesteps=200_000,\n",
    "    progress_bar=True,\n",
    ")\n",
    "model.save('models/ppo-LunarLander-v2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "zpz8kHlt_a_m"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(45.54798891161171, 139.1048836822021)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the model saved at 'models/ppo-LunarLander-v2' and score it\n",
    "model = PPO.load('models/ppo-LunarLander-v2')\n",
    "\n",
    "evaluate_policy(\n",
    "    model,\n",
    "    env,\n",
    "    n_eval_episodes=10,\n",
    "    deterministic=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "oj8PSGHJfwz3"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "45f3a5f31e464240bba8485ed3e659ef",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/144k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "== CURRENT SYSTEM INFO ==\n",
      "- OS: Linux-5.15.0-3.60.5.1.el9uek.x86_64-x86_64-with-glibc2.34 # 2 SMP Wed Oct 19 20:27:31 PDT 2022\n",
      "- Python: 3.9.15\n",
      "- Stable-Baselines3: 1.8.0a1\n",
      "- PyTorch: 1.13.0+cpu\n",
      "- GPU Enabled: False\n",
      "- Numpy: 1.23.5\n",
      "- Gym: 0.26.2\n",
      "\n",
      "== SAVED MODEL SYSTEM INFO ==\n",
      "OS: Linux-5.13.0-40-generic-x86_64-with-debian-bullseye-sid #45~20.04.1-Ubuntu SMP Mon Apr 4 09:38:31 UTC 2022\n",
      "Python: 3.7.10\n",
      "Stable-Baselines3: 1.5.1a5\n",
      "PyTorch: 1.11.0\n",
      "GPU Enabled: False\n",
      "Numpy: 1.21.2\n",
      "Gym: 0.21.0\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/pt39/lib/python3.9/site-packages/stable_baselines3/common/evaluation.py:67: UserWarning: Evaluation environment is not wrapped with a ``Monitor`` wrapper. This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. Consider wrapping environment first with ``Monitor`` wrapper.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(250.9974542026721, 86.61020518339575)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from huggingface_sb3 import load_from_hub\n",
    "\n",
    "# Uncomment to install the hub helper if it is missing:\n",
    "#!pip install huggingface-sb3\n",
    "\n",
    "# Load a model that someone else has already trained\n",
    "# https://huggingface.co/models?library=stable-baselines3\n",
    "checkpoint_path = load_from_hub('araffin/ppo-LunarLander-v2', 'ppo-LunarLander-v2.zip')\n",
    "\n",
    "# Replacement values handed to PPO.load for these saved attributes.\n",
    "overrides = dict(\n",
    "    learning_rate=0.0,\n",
    "    lr_schedule=lambda _: 0.0,\n",
    "    clip_range=lambda _: 0.0,\n",
    ")\n",
    "\n",
    "model = PPO.load(\n",
    "    checkpoint_path,\n",
    "    custom_objects=overrides,\n",
    "    print_system_info=True,\n",
    ")\n",
    "\n",
    "evaluate_policy(\n",
    "    model,\n",
    "    env,\n",
    "    n_eval_episodes=10,\n",
    "    deterministic=False,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "Copie de Unit 1: Train your first Deep Reinforcement Learning Agent 🚀.ipynb",
   "private_outputs": true,
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  },
  "vscode": {
   "interpreter": {
    "hash": "ed7f8024e43d3b8f5ca3c5e1a8151ab4d136b3ecee1e3fd59e0766ccc55e1b10"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
